package crud

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"

	"gitee.com/git-lz/twelve/common/consts"
	"gitee.com/git-lz/twelve/common/dto/dto_task"
	"gitee.com/git-lz/twelve/common/dto/response"
	"gitee.com/git-lz/twelve/common/merrors"
	"gitee.com/git-lz/twelve/common/utils"
	"gitee.com/git-lz/twelve/dao"
	"gitee.com/git-lz/twelve/dao/basedao"
	"gitee.com/git-lz/twelve/model/tables"
	"gitee.com/git-lz/twelve/model/task"
)

// Create persists a new task and its node graph inside a single transaction.
//
// Starting at request.HeadNodeId it walks request.Edge.FromToNodeIds with
// utils.BFS. For every visited node that is defined in request.Nodes it stores
// the node's conditions and actions (via getNodeElements) and then the node
// row itself. Afterwards it stores the task row, pointing at the ID of the
// head-node row and in state consts.TaskStateToAudit, followed by one edge row
// per source node.
//
// Edge rows are written with the IDs of the node rows created here, not with
// the request-local identifiers: the read path (GetAllTasks / getNodesFromDb)
// starts from the task's stored head-node ID and looks every node up by the
// IDs found in the edge rows, so both must live in the same ID space.
//
// Any failure rolls the transaction back and is reported through the returned
// response. On success the response from response.NewResponse is returned
// unchanged.
func (t *TaskManager) Create(ctx context.Context, request *dto_task.CreateTaskReq) *response.Response {
	resp := response.NewResponse()

	tx := basedao.NewDefaultBaseDao().Begin()
	// Roll back on every exit path that did not reach Commit.
	committed := false
	defer func() {
		if !committed {
			tx.Rollback()
		}
	}()

	// createdNodeIDs maps a request-local node identifier to the ID of the
	// node row that was inserted for it.
	createdNodeIDs := make(map[string]uint64, len(request.Nodes))
	createNode := func(node string) error {
		nodeInfo, ok := request.Nodes[node]
		if !ok {
			return nil
		}
		// Guard against a node being reported more than once, so that each
		// request node maps to exactly one row.
		if _, done := createdNodeIDs[node]; done {
			return nil
		}

		nodeElementIds, err := t.getNodeElements(ctx, tx, nodeInfo.Elements)
		if err != nil {
			return err
		}
		nodeElementIdsBytes, err := json.Marshal(nodeElementIds)
		if err != nil {
			return err
		}

		// Create the node row.
		daoNode := &tables.Node{
			Key:      nodeInfo.Key,
			IsAsync:  nodeInfo.Async,
			Elements: nodeElementIdsBytes,
		}
		if err := tx.SetTableName(dao.TableNameNode).SetModel(daoNode).InsertOneRecord(ctx); err != nil {
			return err
		}

		createdNodeIDs[node] = daoNode.ID
		return nil
	}
	if err := utils.BFS(request.HeadNodeId, request.Edge.FromToNodeIds, createNode); err != nil {
		return resp.WithMsg(merrors.ErrnoNodeRangeFailed, err.Error())
	}

	// Without a head-node row the task would point at node 0 and could never
	// be traversed again.
	headNodeID, ok := createdNodeIDs[request.HeadNodeId]
	if !ok {
		return resp.WithMsg(merrors.ErrnoDataFormatFailed,
			fmt.Sprintf("head node %q is not defined in nodes", request.HeadNodeId))
	}

	taskInfo := request.Task.ToDaoTask(headNodeID, consts.TaskStateToAudit)
	if err := tx.SetTableName(dao.TableNameTask).SetModel(taskInfo).InsertOneRecord(ctx); err != nil {
		return resp.WithMsg(merrors.ErrnoMysqlFailed, err.Error())
	}

	// Save the edges, translated from request-local identifiers to row IDs.
	daoEdges := make([]*tables.DagEdge, 0, len(request.Edge.FromToNodeIds))
	for fromNode, toNodes := range request.Edge.FromToNodeIds {
		fromNodeID, ok := createdNodeIDs[fromNode]
		if !ok {
			// No row was created for this source node, so an edge row for it
			// could never be reached from the stored head node.
			continue
		}

		toNodeIDs := make([]string, 0, len(toNodes))
		for _, toNode := range toNodes {
			toNodeID, ok := createdNodeIDs[toNode]
			if !ok {
				return resp.WithMsg(merrors.ErrnoDataFormatFailed,
					fmt.Sprintf("edge %q -> %q points at a node that is not defined in nodes", fromNode, toNode))
			}
			toNodeIDs = append(toNodeIDs, strconv.FormatUint(toNodeID, 10))
		}

		daoEdges = append(daoEdges, &tables.DagEdge{
			TaskId:     taskInfo.ID,
			FromNodeId: fromNodeID,
			ToNodeIds:  strings.Join(toNodeIDs, ","),
		})
	}

	// A single-node task has no edges; do not issue an empty batch insert.
	if len(daoEdges) > 0 {
		if err := tx.SetTableName(dao.TableNameDagEdge).InsertBatch(ctx, daoEdges, len(daoEdges)); err != nil {
			return resp.WithMsg(merrors.ErrnoMysqlFailed, err.Error())
		}
	}

	// NOTE(review): the result of Commit is not inspected here, as before —
	// confirm whether it can report a failure that should reach the caller.
	tx.Commit()
	committed = true
	return resp
}

// getNodeElements stores the condition and the action of every element through
// tx and returns, in input order, one task.NodeElementId per element holding
// the ID fields of the two inserted models (presumably populated by the
// insert — confirm against the DAO).
//
// The first failing conversion or insert stops the loop; its error is returned
// unchanged together with a nil slice. For an empty input the result is a nil
// slice and a nil error.
func (t *TaskManager) getNodeElements(ctx context.Context, tx *basedao.BaseDao, elements []task.NodeElement) ([]task.NodeElementId, error) {
	var ids []task.NodeElementId

	for _, el := range elements {
		// The condition row is written first.
		condition := el.ConditionInfo.ToDaoCondition()
		conditionErr := tx.SetTableName(dao.TableNameCondition).SetModel(condition).InsertOneRecord(ctx)
		if conditionErr != nil {
			return nil, conditionErr
		}

		// Then the action is converted and written.
		action, convErr := el.ActionInfo.ToDaoAction()
		if convErr != nil {
			return nil, convErr
		}
		actionErr := tx.SetTableName(dao.TableNameAction).SetModel(action).InsertOneRecord(ctx)
		if actionErr != nil {
			return nil, actionErr
		}

		ids = append(ids, task.NodeElementId{
			ConditionId: condition.ID,
			ActionId:    action.ID,
		})
	}

	return ids, nil
}

// GetAllTasks loads every task together with its edges and the nodes reachable
// from its head node, and returns them as the response data ([]task.Task).
//
// For each task the edge rows with a matching task_id and deleted = 0 are read
// and turned into a map from source node ID to target node IDs. That map and
// the task's head node are then handed to getNodesFromDb. The first query or
// traversal failure aborts the whole listing with merrors.ErrnoMysqlFailed.
//
// NOTE(review): tasks are fetched with a nil condition while edges are
// filtered on deleted = 0 — confirm whether deleted tasks should be excluded
// as well.
func (t *TaskManager) GetAllTasks(ctx context.Context) *response.Response {
	resp := response.NewResponse()

	var tasks []tables.Task
	if err := dao.NewTaskDao().GetsByCond(ctx, nil, &tasks); err != nil {
		return resp.WithMsg(merrors.ErrnoMysqlFailed, err.Error())
	}

	var res []task.Task
	for i := range tasks {
		tk := &tasks[i]

		// Load the edges of this task.
		var edges []tables.DagEdge
		if err := dao.NewDagEdgeDao().GetsByCond(ctx, map[string]interface{}{
			"task_id": tk.ID,
			"deleted": 0,
		}, &edges); err != nil {
			return resp.WithMsg(merrors.ErrnoMysqlFailed, err.Error())
		}

		edgesMap := make(map[string][]string, len(edges))
		for _, edge := range edges {
			// strings.Split("", ",") returns [""], not an empty slice. An edge
			// row without targets must not produce a phantom "" node, which
			// getNodesFromDb would fail to parse as a node ID.
			toNodeIDs := []string{}
			if edge.ToNodeIds != "" {
				toNodeIDs = strings.Split(edge.ToNodeIds, ",")
			}
			edgesMap[strconv.FormatUint(edge.FromNodeId, 10)] = toNodeIDs
		}
		edgeInfo := &task.EdgeInfo{
			FromToNodeIds: edgesMap,
		}

		nodes, err := t.getNodesFromDb(ctx, edgesMap, tk.HeadNode)
		if err != nil {
			return resp.WithMsg(merrors.ErrnoMysqlFailed, err.Error())
		}

		res = append(res, task.Task{
			TriggerType:       tk.TriggerType,
			TriggerKey:        tk.TriggerKey,
			Version:           tk.Version,
			BodyPojoClassName: tk.BodyPojoClassName,
			HeadNodeId:        strconv.FormatUint(tk.HeadNode, 10),
			GetResponseMethod: tk.Response,
			Nodes:             nodes,
			Edge:              edgeInfo,
		})
	}

	return resp.WithData(res)
}

// getNodesFromDb walks the graph described by edgesMap with utils.BFS,
// starting at headNodeId, and loads every visited node from the database.
//
// Each visited node string is parsed as a base-10 uint64 and fetched by that
// ID. The node's Elements column is decoded as a JSON list of
// task.NodeElementId, and for every entry the referenced action and condition
// rows are fetched. The result maps the node string used during the traversal
// to the assembled task.NodeInfo.
//
// The first parse, query or decode error aborts the traversal and is returned
// unchanged with a nil map.
func (t *TaskManager) getNodesFromDb(ctx context.Context, edgesMap map[string][]string, headNodeId uint64) (map[string]*task.NodeInfo, error) {
	result := make(map[string]*task.NodeInfo)

	// loadElement resolves one stored (action, condition) ID pair into a
	// task.NodeElement; the action is fetched before the condition.
	loadElement := func(ref task.NodeElementId) (task.NodeElement, error) {
		var action tables.Action
		if actionErr := dao.NewActionDao().GetById(ctx, ref.ActionId, &action); actionErr != nil {
			return task.NodeElement{}, actionErr
		}

		var condition tables.Condition
		if conditionErr := dao.NewConditionDao().GetById(ctx, ref.ConditionId, &condition); conditionErr != nil {
			return task.NodeElement{}, conditionErr
		}

		return task.NodeElement{
			ConditionInfo: &task.ConditionInfo{
				Type:       condition.Type,
				Name:       condition.Name,
				Expression: string(condition.Expression),
			},
			ActionInfo: &task.ActionInfo{
				Type:        action.Type,
				Config:      action.Config,
				Description: action.Description,
			},
		}, nil
	}

	// visit is called for every node of the traversal.
	visit := func(node string) error {
		id, parseErr := strconv.ParseUint(node, 10, 64)
		if parseErr != nil {
			return parseErr
		}

		// Fetch the node row.
		var row tables.Node
		if nodeErr := dao.NewNodeDao().GetById(ctx, id, &row); nodeErr != nil {
			return nodeErr
		}

		// Decode the element references stored on the node.
		var refs []task.NodeElementId
		if decodeErr := json.Unmarshal(row.Elements, &refs); decodeErr != nil {
			return decodeErr
		}

		// Resolve every reference into its action and condition.
		var elements []task.NodeElement
		for _, ref := range refs {
			element, loadErr := loadElement(ref)
			if loadErr != nil {
				return loadErr
			}
			elements = append(elements, element)
		}

		result[node] = &task.NodeInfo{
			Key:      row.Key,
			Async:    row.IsAsync,
			Elements: elements,
		}
		return nil
	}

	head := fmt.Sprintf("%d", headNodeId)
	if err := utils.BFS(head, edgesMap, visit); err != nil {
		return nil, err
	}

	return result, nil
}
